{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "## VOC12: 20 class names; embeddings are produced in this order\n",
    "VOC12_classes = [\n",
    "    'aeroplane', 'bicycle', 'bird', 'boat', 'bottle',\n",
    "    'bus', 'car', 'cat', 'chair', 'cow',\n",
    "    'diningtable', 'dog', 'horse', 'motorbike', 'person',\n",
    "    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor',\n",
    "]\n",
    "\n",
    "## COCO\n",
    "# 171 class names: 80 entries from 'person' to 'toothbrush', followed by 91 entries\n",
    "# from 'banner' to 'wood'. List order determines the row order of the saved embeddings.\n",
    "# NOTE(review): this looks like the COCO-Stuff things + stuff label set - confirm against the dataset loader.\n",
    "COCO_classes = ['person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',\n",
    "        'truck', 'boat', 'traffic light', 'fire hydrant', 'stop sign',\n",
    "        'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',\n",
    "        'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',\n",
    "        'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',\n",
    "        'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',\n",
    "        'surfboard', 'tennis racket', 'bottle', 'wine glass', 'cup', 'fork',\n",
    "        'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',\n",
    "        'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',\n",
    "        'couch', 'potted plant', 'bed', 'dining table', 'toilet', 'tv',\n",
    "        'laptop', 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',\n",
    "        'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',\n",
    "        'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'banner',\n",
    "        'blanket', 'branch', 'bridge', 'building-other', 'bush', 'cabinet',\n",
    "        'cage', 'cardboard', 'carpet', 'ceiling-other', 'ceiling-tile',\n",
    "        'cloth', 'clothes', 'clouds', 'counter', 'cupboard', 'curtain',\n",
    "        'desk-stuff', 'dirt', 'door-stuff', 'fence', 'floor-marble',\n",
    "        'floor-other', 'floor-stone', 'floor-tile', 'floor-wood',\n",
    "        'flower', 'fog', 'food-other', 'fruit', 'furniture-other', 'grass',\n",
    "        'gravel', 'ground-other', 'hill', 'house', 'leaves', 'light', 'mat',\n",
    "        'metal', 'mirror-stuff', 'moss', 'mountain', 'mud', 'napkin', 'net',\n",
    "        'paper', 'pavement', 'pillow', 'plant-other', 'plastic', 'platform',\n",
    "        'playingfield', 'railing', 'railroad', 'river', 'road', 'rock', 'roof',\n",
    "        'rug', 'salad', 'sand', 'sea', 'shelf', 'sky-other', 'skyscraper',\n",
    "        'snow', 'solid-other', 'stairs', 'stone', 'straw', 'structural-other',\n",
    "        'table', 'tent', 'textile-other', 'towel', 'tree', 'vegetable',\n",
    "        'wall-brick', 'wall-concrete', 'wall-other', 'wall-panel',\n",
    "        'wall-stone', 'wall-tile', 'wall-wood', 'water-other', 'waterdrops',\n",
    "        'window-blind', 'window-other', 'wood']\n",
    "\n",
    "# Prompt templates; '{}' is filled with a class name via str.format.\n",
    "# Two groups, concatenated in order: 8 'photo' prompts, then 7 'in the scene' prompts.\n",
    "templates = [\n",
    "    'a photo of a {}.',\n",
    "    'a photo of a small {}.',\n",
    "    'a photo of a medium {}.',\n",
    "    'a photo of a large {}.',\n",
    "    'This is a photo of a {}.',\n",
    "    'This is a photo of a small {}.',\n",
    "    'This is a photo of a medium {}.',\n",
    "    'This is a photo of a large {}.',\n",
    "] + [\n",
    "    'a {} in the scene.',\n",
    "    'a photo of a {} in the scene.',\n",
    "    'There is a {} in the scene.',\n",
    "    'There is the {} in the scene.',\n",
    "    'This is a {} in the scene.',\n",
    "    'This is the {} in the scene.',\n",
    "    'This is one {} in the scene.',\n",
    "]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "### From CLIP https://colab.research.google.com/github/openai/clip\n",
    "import os\n",
    "\n",
    "import torch\n",
    "import numpy as np\n",
    "import clip\n",
    "# Use the GPU when torch reports one is available; otherwise fall back to the CPU.\n",
    "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
    "# Load the 'ViT-B/16' CLIP model onto that device. `preprocess` is also returned by\n",
    "# clip.load but is not used by the cells visible in this notebook.\n",
    "model, preprocess = clip.load('ViT-B/16', device)\n",
    "\n",
    "## single template\n",
    "def single_templete(save_path, class_names, model):\n",
    "    \"\"\"Encode one 'a photo of a <class>' prompt per class with CLIP and save the result.\n",
    "\n",
    "    Args:\n",
    "        save_path: destination .npy file; its parent directory is created if missing.\n",
    "        class_names: iterable of class-name strings; one prompt is built per name.\n",
    "        model: CLIP model providing `encode_text`.\n",
    "\n",
    "    Returns:\n",
    "        Tensor holding one L2-normalised text embedding per class, in the order of\n",
    "        `class_names`, on the same device as the model.\n",
    "    \"\"\"\n",
    "    # Follow the model's own device instead of forcing CUDA, so the CPU fallback\n",
    "    # chosen when the model is loaded keeps working.\n",
    "    model_device = next(model.parameters()).device\n",
    "    with torch.no_grad():\n",
    "        texts = torch.cat([clip.tokenize(f\"a photo of a {c}\") for c in class_names]).to(model_device)\n",
    "        text_embeddings = model.encode_text(texts)\n",
    "        # Normalise each row to unit length.\n",
    "        text_embeddings = text_embeddings / text_embeddings.norm(dim=-1, keepdim=True)\n",
    "    # np.save does not create missing directories, so make sure the target exists.\n",
    "    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)\n",
    "    np.save(save_path, text_embeddings.cpu().numpy())\n",
    "    return text_embeddings\n",
    "\n",
    "## multi templates\n",
    "def multi_templete(save_path, class_names, model, templates):\n",
    "    \"\"\"Encode every class with several prompt templates, average them, and save the result.\n",
    "\n",
    "    For each class name, every template is filled in, encoded with CLIP, and\n",
    "    L2-normalised; the per-template embeddings are averaged and the mean is\n",
    "    normalised again to give one embedding per class.\n",
    "\n",
    "    Args:\n",
    "        save_path: destination .npy file; its parent directory is created if missing.\n",
    "        class_names: iterable of class-name strings.\n",
    "        model: CLIP model providing `encode_text`.\n",
    "        templates: iterable of format strings, each with one '{}' placeholder.\n",
    "\n",
    "    Returns:\n",
    "        Tensor holding one unit-length embedding per class, in the order of\n",
    "        `class_names`, on the same device as the model.\n",
    "    \"\"\"\n",
    "    # Follow the model's own device instead of forcing CUDA, so the CPU fallback\n",
    "    # chosen when the model is loaded keeps working.\n",
    "    model_device = next(model.parameters()).device\n",
    "    with torch.no_grad():\n",
    "        text_embeddings = []\n",
    "        for classname in class_names:\n",
    "            texts = [template.format(classname) for template in templates]  # fill in the class name\n",
    "            texts = clip.tokenize(texts).to(model_device)\n",
    "            class_embeddings = model.encode_text(texts)\n",
    "            class_embeddings /= class_embeddings.norm(dim=-1, keepdim=True)\n",
    "            # Average over templates, then re-normalise the mean to unit length.\n",
    "            class_embedding = class_embeddings.mean(dim=0)\n",
    "            class_embedding /= class_embedding.norm()\n",
    "            text_embeddings.append(class_embedding)\n",
    "        # The per-class tensors already live on the model's device; no transfer needed.\n",
    "        text_embeddings = torch.stack(text_embeddings, dim=0)\n",
    "    # np.save does not create missing directories, so make sure the target exists.\n",
    "    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)\n",
    "    np.save(save_path, text_embeddings.cpu().numpy())\n",
    "    return text_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Single-template embeddings: VOC12 first, then COCO\n",
    "for save_path, class_names in [\n",
    "    ('./text_embeddings/voc12_single.npy', VOC12_classes),\n",
    "    ('./text_embeddings/coco_single.npy', COCO_classes),\n",
    "]:\n",
    "    text_embeddings = single_templete(save_path, class_names, model)\n",
    "\n",
    "## Multi-template embeddings (COCO only)\n",
    "save_path = './text_embeddings/coco_multi.npy'\n",
    "text_embeddings = multi_templete(save_path, COCO_classes, model, templates)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.7.13 64-bit ('pt10': conda)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.13"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "d51e6962ef4cca746a608d0a67209086d135a3414651ca2ba27acebb6daa21dd"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
